#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import getpass
import json
import os
import queue
import re
import shlex
import subprocess
import threading

from pysnmp.hlapi import *

import AutoExecUtils


class SnmpCollectWorker(threading.Thread):
    # snmpwalk -v3 -l authPriv -u monitor -a sha -A AuthPassword -x AES -X DataPassword 192.168.0.1 SNMPv2-MIB::sysDescr.0
    def __init__(self, name, execQueue, snmpPort, db, pluginDir, communities, unknownItems):
        """Set up a daemon worker thread that processes discovery targets from a queue.

        name -- thread name.
        execQueue -- queue the thread takes target dicts from.
        snmpPort -- port used for the SNMP probe.
        db -- indexable handle; db["_discovery_rule"] is kept as the rule collection.
        pluginDir -- directory the collector commands are looked up in.
        communities -- SNMP credentials handed to the probe.
        unknownItems -- list that receives targets which could not be collected.
        """
        super().__init__(name=name, daemon=True)
        # Collaborators supplied by the caller.
        self._queue = execQueue
        self.snmpPort = snmpPort
        self.pluginDir = pluginDir
        self.communities = communities
        self.unknownItems = unknownItems
        self.ruleCollection = db["_discovery_rule"]
        # Mutable per-thread state.
        self.goToStop = False
        self.errorCount = 0

    def getObjCatByRule(self, objInfo):
        matched = False
        sysObjId = objInfo.get("SYS_OBJECT_ID")
        sysDescr = objInfo.get("SYS_DESCR")
        collection = self.ruleCollection
        for rule in collection.find({"sysObjectId": sysObjId}, {"_id": False}):
            descrPattern = rule.get("sysDescrPattern")
            if descrPattern is None or re.match(descrPattern, sysDescr):
                matched = True
                objInfo["_OBJ_CATEGORY"] = rule["_OBJ_CATEGORY"]
                objInfo["_OBJ_TYPE"] = rule["_OBJ_TYPE"]
                objInfo["VENDOR"] = rule["VENDOR"]
                objInfo["MODEL"] = rule["MODEL"]
                break

        return matched

    def getSnmpV3UsmUserData(self, authDict):
        """Build the pysnmp USM credentials for one SNMPv3 user description.

        Keys read from authDict: "userName", "authPassword", "authProtocol",
        "privPassword" and "privProtocol", for example
        {"userName": "monitor", "authPassword": "AuthPassword", "authProtocol": "sha",
         "privPassword": "DataEncryptPassword", "privProtocol": "aes"}

        Protocol names are compared case-insensitively. A missing, null or empty
        protocol selects the no-auth / no-priv constant. An unrecognised name leaves
        the value chosen by UsmUserData untouched and prints a warning.

        Returns the configured UsmUserData object.
        """
        usmUserData = UsmUserData(
            authDict.get("userName", None),
            authKey=authDict.get("authPassword", None),
            privKey=authDict.get("privPassword", None),
        )

        # The tables hold attribute names of pysnmp's ``auth`` module; the attribute is
        # looked up only for the selected entry, as the original if/elif chain did.
        authProtocolAttrs = {
            "": "usmNoAuthProtocol",
            "MD5": "usmHMACMD5AuthProtocol",
            "HMAC-MD5": "usmHMACMD5AuthProtocol",
            "SHA": "usmHMACSHAAuthProtocol",
            "SHA1": "usmHMACSHAAuthProtocol",
            "SHA-1": "usmHMACSHAAuthProtocol",
            "SHA2": "usmHMAC128SHA224AuthProtocol",
            "SHA224": "usmHMAC128SHA224AuthProtocol",
            "SHA-224": "usmHMAC128SHA224AuthProtocol",
            "SHA256": "usmHMAC192SHA256AuthProtocol",
            "SHA-256": "usmHMAC192SHA256AuthProtocol",
            "SHA384": "usmHMAC256SHA384AuthProtocol",
            "SHA-384": "usmHMAC256SHA384AuthProtocol",
            "SHA512": "usmHMAC384SHA512AuthProtocol",
            # Was mistyped as "SHA-256", which made "SHA-512" unusable.
            "SHA-512": "usmHMAC384SHA512AuthProtocol",
        }
        privProtocolAttrs = {
            "": "usmNoPrivProtocol",
            "AES": "usmAesCfb128Protocol",
            "AES128": "usmAesCfb128Protocol",
            "AES-128": "usmAesCfb128Protocol",
            "AES192": "usmAesCfb192Protocol",
            "AES-192": "usmAesCfb192Protocol",
            "AES256": "usmAesCfb256Protocol",
            "AES-256": "usmAesCfb256Protocol",
            # Exact names only: the former ``in ("DES")`` was a substring test on a
            # string, so values such as "D" or "ES" silently selected DES.
            "DES": "usmDESPrivProtocol",
            "3DES": "usm3DESEDEPrivProtocol",
        }

        # ``or ""`` also covers an explicit JSON null, which used to crash on .upper().
        authProtocol = str(authDict.get("authProtocol") or "").strip().upper()
        authAttr = authProtocolAttrs.get(authProtocol)
        if authAttr is not None:
            usmUserData.authProtocol = getattr(auth, authAttr)
        else:
            print("WARN: Unknown snmp v3 authProtocol {}, keep the default.\n".format(authProtocol), end="")

        privProtocol = str(authDict.get("privProtocol") or "").strip().upper()
        privAttr = privProtocolAttrs.get(privProtocol)
        if privAttr is not None:
            usmUserData.privProtocol = getattr(auth, privAttr)
        else:
            print("WARN: Unknown snmp v3 privProtocol {}, keep the default.\n".format(privProtocol), end="")

        return usmUserData

    def snmpGetSysInfo(self, objInfo, communities):
        """Probe a device over SNMP and return the first credential that identifies it.

        For each entry of communities (a community string, or a dict describing an
        SNMPv3 user) a single GET for sysObjectID, sysDescr and sysName is sent to
        objInfo["MGMT_IP"] on self.snmpPort. A successful answer fills
        objInfo["SYS_OBJECT_ID"] (with a leading "."), objInfo["SYS_DESCR"] and
        objInfo["DEV_NAME"]; getObjCatByRule() then tries to classify the device.

        Returns the credential whose answer was classified by a rule, or None when
        no credential produced a classified answer.
        """
        ip = objInfo.get("MGMT_IP")

        for community in communities:
            if isinstance(community, dict):
                authData = self.getSnmpV3UsmUserData(community)
            else:
                authData = CommunityData(community)

            iterator = getCmd(
                SnmpEngine(),
                authData,
                UdpTransportTarget((ip, self.snmpPort)),
                ContextData(),
                ObjectType(ObjectIdentity("1.3.6.1.2.1.1.2.0")),
                ObjectType(ObjectIdentity("1.3.6.1.2.1.1.1.0")),
                ObjectType(ObjectIdentity("1.3.6.1.2.1.1.5.0")),
            )
            errorIndication, errorStatus, errorIndex, varBinds = next(iterator)

            if errorIndication:
                print(errorIndication)
                continue

            if errorStatus:
                print(
                    "%s at %s"
                    % (
                        errorStatus.prettyPrint(),
                        varBinds[int(errorIndex) - 1][0] if errorIndex else "?",
                    )
                )
                continue

            # Tracked per credential. The flag used to be shared by all iterations, so
            # after one success every later (failing) credential was also treated as
            # successful and overwrote the credential that is returned.
            gotObjectId = False
            for varBind in varBinds:
                oid = str(varBind[0])
                oidVal = str(varBind[1])
                if oid == "1.3.6.1.2.1.1.2.0":
                    objInfo["SYS_OBJECT_ID"] = "." + oidVal
                    gotObjectId = True
                elif oid == "1.3.6.1.2.1.1.1.0":
                    objInfo["SYS_DESCR"] = oidVal
                elif oid == "1.3.6.1.2.1.1.5.0":
                    objInfo["DEV_NAME"] = oidVal

            if not gotObjectId:
                continue

            print("INFO: Snmp get for {} success.\n".format(ip), end="")
            if self.getObjCatByRule(objInfo):
                print(
                    "INFO: Device regconized rule matched for {}.\n".format(ip),
                    end="",
                )
                # Stop here: this credential works, trying the rest only adds timeouts.
                return community

            print(
                "INFO: There no device regconized rule matched for {}.\n".format(ip),
                end="",
            )

        return None

    def getCollectCmd(self, objInfo, nodeName, community):
        cmd = None
        objCat = objInfo.get("_OBJ_CATEGORY")
        objType = objInfo.get("_OBJ_TYPE")

        nodeInfo = {
            "resourceId": 0,
            "nodeName": nodeName,
            "host": objInfo.get("MGMT_IP"),
            "port": None,
            "protocol": "snmp",
            "protocolPort": 161,
            "username": "none",
            # "password": community,
            "nodeType": objType,
        }

        if isinstance(community, dict):
            nodeInfo["password"] = community.get("authPassword")
            nodeInfo["authData"] = json.dumps(community, ensure_ascii=False)
        else:
            nodeInfo["password"] = community

        nodeJsonStr = json.dumps(nodeInfo, ensure_ascii=False)

        if objCat == "SWITCH":
            cmd = "{}/switchcollector --objtype {}".format(self.pluginDir, objType)
        elif objCat == "LOADBALANCER":
            if objType == "F5":
                cmd = "{}/f5collector".format(self.pluginDir)
            elif objType == "A10":
                cmd = "{}/a10collector".format(self.pluginDir)
        elif objCat == "SECDEV":
            if objType == "FireWall":
                cmd = "{}/firewallcollector --type auto".format(self.pluginDir)
        elif objCat == "FCDEV":
            if objCat == "FCSwitch ":
                cmd = "{}/storagecollector --type auto".format(self.pluginDir)

        if cmd is not None:
            cmd = "{} --node '{}'".format(cmd, nodeJsonStr)
            ip = objInfo["MGMT_IP"]
            if not os.path.exists(ip):
                os.mkdir(ip)

            cmd = "cd '{}' && {}".format(ip, cmd)

        return cmd

    def saveOneNode(self, ip):
        """Run the savedata tool on '<ip>/output.json' and return the os.system() status."""
        return os.system("%s/savedata --outputfile '%s/output.json'" % (self.pluginDir, ip))

    def collect(self, objInfo):
        """Identify one device over SNMP, run its collector and save the result.

        Steps:
          1. snmpGetSysInfo() probes the device with self.communities. Without a
             working, classified credential the device is marked "UNKNOWN" and
             appended to self.unknownItems.
          2. Otherwise the collector command from getCollectCmd() is executed. A
             missing collector or a non-zero status also sends the device to
             self.unknownItems.
          3. When everything succeeded, saveOneNode() stores the collected data.

        Returns the accumulated error status (0 means success). Any non-zero status
        also increments self.errorCount; previously the status was computed but
        never recorded, so the counter always stayed at 0.
        """
        hasError = 0

        # Detect the object type by SNMP sysObjectId, sysDescr and sysName.
        workCommunity = self.snmpGetSysInfo(objInfo, self.communities)
        if workCommunity is None:
            # Getting sysObjectId over SNMP and deriving the type from it failed.
            hasError += 1
            objInfo["_OBJ_CATEGORY"] = "UNKNOWN"
            self.unknownItems.append(objInfo)
        else:
            # Call the collect tool.
            collectCmd = self.getCollectCmd(objInfo, objInfo.get("DEV_NAME"), workCommunity)
            if collectCmd is not None:
                print(
                    "INFO: Collect information for " + objInfo["MGMT_IP"] + ".\n",
                    end="",
                )
                hasError += os.system(collectCmd)
            else:
                hasError += 1
            if hasError != 0:
                self.unknownItems.append(objInfo)

        # Call savedata only for a fully collected device.
        if hasError == 0:
            print(
                "INFO: Object catetory:{} object type:{} IP:{} collected.\n".format(
                    objInfo.get("_OBJ_CATEGORY"),
                    objInfo.get("_OBJ_TYPE"),
                    objInfo.get("MGMT_IP"),
                ),
                end="",
            )
            hasError += self.saveOneNode(objInfo.get("MGMT_IP"))

        if hasError != 0:
            self.errorCount += 1

        return hasError

    def run(self):
        while not self.goToStop:
            objInfo = self._queue.get()
            if objInfo is None:
                break
            self.collect(objInfo)


class NmapScan:
    def __init__(self):
        """Create a scanner; no instance state is kept."""

    def parse(self, hostInfo):
        """Turn one scanned host into a discovery object with a guessed category.

        hostInfo -- host dict produced by scan(); "IP" is required, "OS" is optional.

        Returns {"MGMT_IP", "_OBJ_CATEGORY", "_OBJ_TYPE"}. Category and type stay
        "UNKNOWN" when there is no OS description or nothing matches. Keywords are
        searched as case-insensitive substrings of the OS description, in the order
        listed; the first hit wins. Switch vendors get "_OBJ_TYPE" None.
        """
        objInfo = {
            "MGMT_IP": hostInfo["IP"],
            # Was misspelled "_OBJ_CATETORY", so hosts never carried a default category.
            "_OBJ_CATEGORY": "UNKNOWN",
            "_OBJ_TYPE": "UNKNOWN",
        }

        osDesc = hostInfo.get("OS")
        if osDesc is None:
            return objInfo

        # (keyword, category, type)
        keywordRules = (
            ("Linux", "OS", "Linux"),
            ("Windows", "OS", "Windows"),
            ("AIX", "OS", "AIX"),
            ("SunOS", "OS", "SunOS"),
            ("FreeBSD", "OS", "FreeBSD"),
            ("CISCO", "SWITCH", None),
            ("Huawei", "SWITCH", None),
            ("Juniper", "SWITCH", None),
            ("Ruijie", "SWITCH", None),
            ("H3C", "SWITCH", None),
        )

        # Compare case-insensitively: with exact-case matching the "CISCO" keyword
        # could never hit an OS description spelled "Cisco ...".
        osDescLower = osDesc.lower()
        for keyword, category, objType in keywordRules:
            if keyword.lower() in osDescLower:
                objInfo["_OBJ_CATEGORY"] = category
                objInfo["_OBJ_TYPE"] = objType
                return objInfo

        # Pattern rules are evaluated after all keyword rules, as before.
        if re.search(r"Google Android", osDesc):
            objInfo["_OBJ_CATEGORY"] = "MOBILE"
            objInfo["_OBJ_TYPE"] = "Android"

        return objInfo

    def scan(self, net=None, ports=None, timingTmpl=5, outputLines=None):
        """Run an nmap port/OS scan through sudo and parse its grepable (-oG) output.

        net -- target specification for nmap; whitespace-separated targets are passed
               as separate arguments.
        ports -- comma-separated port list for -p (default "22,161,135,139,445,3389").
        timingTmpl -- number used for nmap's -T timing template.
        outputLines -- optional iterable of already captured grepable output lines
                       (str or bytes). When given, nmap is not executed and only
                       these lines are parsed.

        Returns a dict keyed by IP. Each value holds "IP" and "Name" plus, when the
        output contains them, "Status", "Ports" (open ports only, keyed by port
        number), "Protocols" (keyed by protocol name) and every other field under
        its own name as raw text.

        Raises ValueError when neither net nor outputLines is supplied, and OSError
        when the sudo/nmap process cannot be started.
        """
        if outputLines is not None:
            return self._parseGrepableOutput(outputLines)

        if net is None:
            raise ValueError("net is required to run nmap")
        if ports is None:
            ports = "22,161,135,139,445,3389"

        # Example: nmap -oG - 192.168.0.1/24 -p 22,161,135,139,445,3389 -T5 -sSU --top-ports 100 -n -O
        # An argument list (no shell) keeps target/port text from being interpreted by a shell.
        nmapCmd = ["sudo", "nmap", "-oG", "-"]
        nmapCmd.extend(str(net).split())
        nmapCmd.extend(["-p", str(ports), "-T{}".format(timingTmpl), "-O", "-n", "-sSU", "--top-ports", "5"])

        # The context manager closes the pipe and waits for the child, so no zombie
        # process is left behind. stdin is /dev/null because nothing is ever written.
        with subprocess.Popen(
            nmapCmd,
            close_fds=True,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        ) as child:
            # Whole lines are read: the pipe is drained continuously, so the writer
            # cannot block, and long lines are no longer cut at 4096 bytes.
            return self._parseGrepableOutput(child.stdout)

    def _parseGrepableOutput(self, lines):
        """Parse an iterable of grepable output lines (str or bytes) into the host dict."""
        result = {}
        for rawLine in lines:
            if isinstance(rawLine, bytes):
                rawLine = rawLine.decode("utf-8", errors="replace")
            line = rawLine.strip()
            if line == "" or line.startswith("#") or line.startswith("RTTVAR has grown to over"):
                continue
            self._parseGrepableLine(line, result)
        return result

    def _parseGrepableLine(self, line, result):
        """Merge the tab-separated "Name: value" fields of one output line into result."""
        hostInfo = None
        for field in line.split("\t"):
            # Split on the first ": " only; values (e.g. version text) may contain ": ".
            name, sep, val = field.partition(": ")
            if sep == "":
                print(line)
                continue

            if name == "Host":
                ip, _, dns = val.partition(" ")
                ip = ip.strip()
                hostInfo = result.setdefault(ip, {})
                hostInfo["IP"] = ip
                # The host name is printed in parentheses after the address.
                hostInfo["Name"] = dns.strip()[1:-1]
                continue

            if hostInfo is None:
                # A field without a preceding Host field is not scan data; typically
                # it is an error message from sudo or nmap. Report it and skip the
                # line instead of failing on a missing host record.
                try:
                    user = getpass.getuser()
                except (KeyError, OSError, ImportError):
                    user = "unknown"
                print("WARN: Execute nmap failed, check the sudoers setting for user %s to execute command nmap." % (user))
                print(line)
                return

            if name == "Ports":
                portsInfo = hostInfo.setdefault("Ports", {})
                for portSeg in val.split("/, "):
                    # (port, state, protocol, owner, service, rpc_info, version)
                    portParts = portSeg.split("/")
                    portParts.extend([""] * (7 - len(portParts)))
                    if portParts[1] == "open":
                        # NOTE: keyed by port number only, so a TCP and a UDP entry
                        # for the same number share one slot (last one wins).
                        portsInfo[portParts[0]] = {
                            "port": portParts[0],
                            "state": portParts[1],
                            "protocol": portParts[2],
                            "owner": portParts[3],
                            "service": portParts[4],
                            "rpc_info": portParts[5],
                            "version": portParts[6],
                        }
            elif name == "Protocols":
                protocolsInfo = hostInfo.setdefault("Protocols", {})
                for protocolSeg in val.split("/, "):
                    # (number, state, name)
                    protocolParts = protocolSeg.split("/")
                    if len(protocolParts) < 3:
                        continue
                    # The original read these values from portParts, i.e. from the
                    # last parsed port (or raised NameError when there was none).
                    protocolsInfo[protocolParts[2]] = {
                        "number": protocolParts[0],
                        "state": protocolParts[1],
                        "name": protocolParts[2],
                    }
            elif name == "Status":
                hostInfo["Status"] = val.strip()
            else:
                hostInfo[name] = val


def buildWorkerPool(workerCount, execQueue, snmpPort, db, pluginDir, communities, unknownItems):
    """Create and start workerCount SnmpCollectWorker threads that share execQueue.

    All workers receive the same snmpPort, db, pluginDir, communities and
    unknownItems objects and are named "Worker-<index>".

    Returns the list of started threads.
    """
    workers = []
    for i in range(workerCount):
        worker = SnmpCollectWorker(
            "Worker-{}".format(i),
            execQueue,
            snmpPort,
            db,
            pluginDir,
            communities,
            unknownItems,
        )
        # Thread.setDaemon() is deprecated since Python 3.10; assign the attribute instead.
        worker.daemon = True
        worker.start()
        workers.append(worker)

    return workers


def discovery(snmpPort, workerCount, net, ports, communities, timingTmpl):
    """Scan one network with nmap and collect every SNMP-capable device found.

    Hosts whose OS was classified as "OS", and hosts without an open entry for
    snmpPort, are recorded directly. All other hosts are queued for the
    SnmpCollectWorker pool. Hosts recorded directly, together with those the
    workers could not collect, are written through AutoExecUtils.saveOutput()
    and the savedata tool.

    Returns an error status: the sum of the workers' error counters plus the
    exit status of the final savedata call (0 means no error was recorded).
    """
    pluginDir = os.path.realpath(os.path.dirname(os.path.realpath(__file__)))

    nmapScan = NmapScan()
    result = nmapScan.scan(net=net, ports=ports, timingTmpl=timingTmpl)

    (dbclient, db) = AutoExecUtils.getDB()

    otherItems = []

    execQueue = queue.Queue(workerCount * 2)
    workerThreads = buildWorkerPool(workerCount, execQueue, snmpPort, db, pluginDir, communities, otherItems)

    hasError = 0
    snmpPortStr = str(snmpPort)
    try:
        for hostInfo in result.values():
            portsInfo = hostInfo.get("Ports", {})

            objInfo = nmapScan.parse(hostInfo)
            if objInfo.get("_OBJ_CATEGORY") != "OS" and portsInfo.get(snmpPortStr) is not None:
                execQueue.put(objInfo)
            else:
                otherItems.append(objInfo)
    finally:
        # Enqueue exactly one exit sentinel per worker thread.
        for _ in workerThreads:
            execQueue.put(None)

        # Wait for every worker to exit. The counter is read once, after the thread
        # has finished; it used to be added again on every 3 second join retry.
        for worker in workerThreads:
            while worker.is_alive():
                worker.join(3)
            hasError = hasError + worker.errorCount

        if dbclient is not None:
            dbclient.close()

    if len(otherItems) > 0:
        out = {"DATA": otherItems}
        AutoExecUtils.saveOutput(out)
        saveCmd = "{}/savedata --outputfile output.json".format(pluginDir)
        hasError = hasError + os.system(saveCmd)

    return hasError


def parseCommunities(communitiesJson):
    """Decode the --communities argument into a list of SNMP credentials.

    communitiesJson -- JSON text: an array of community strings and/or SNMPv3 user
                       objects. A single string or object is accepted as well and
                       wrapped into a one-element list.

    Returns the decoded list. When the text is not valid JSON, is not a
    list/string/object, or is an empty list, a warning is printed and
    ["public"] is returned.
    """
    try:
        parsed = json.loads(communitiesJson)
    except (TypeError, ValueError):
        parsed = None

    if isinstance(parsed, (str, dict)):
        parsed = [parsed]

    if not isinstance(parsed, list) or len(parsed) == 0:
        print('WARN: Param communities is not in json format, example:["public",{"userName":"monitor","authPassword":"AuthPassword","authProtocol":"sha","privPassword":"DataEncryptPassword","privProtocol":"aes"}]')
        return ["public"]

    return parsed


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--nets",
        default="192.168.0.0/24",
        help="Discovery nets, example:192.168.0.0/24,192.168.1.0/24",
    )
    parser.add_argument(
        "--ports",
        default="22,161,135,139,445,3389,3939",
        help="Scan ports, default:22,161,135,139,445,3389,3939",
    )
    parser.add_argument("--snmpport", default=161, help="Snmp Port")
    # The help text names the keys the SNMPv3 code actually reads (authPassword / privPassword).
    parser.add_argument(
        "--communities",
        default='["public"]',
        help='Snmp Communities(JSON Array), example:snmpv2:["public","mary"]或snmpv3:[{"userName": "monitor", "authPassword": "AuthPasssword", "authProtocol": "md5", "privPassword": "EncryptDataPassword", "privProtocol": "aes"}]',
    )
    parser.add_argument("--workercount", default=16, help="Worker thread counts")
    parser.add_argument("--timingtmpl", default=4, help="Timing template, 1-5, 5 is fastest")

    args = parser.parse_args()

    snmpPort = int(args.snmpport)
    workerCount = int(args.workercount)
    timingTmpl = int(args.timingtmpl)

    ports = args.ports

    # The decoded value used to be thrown away, so --communities never had any effect.
    communities = parseCommunities(args.communities)

    autoexecHome = os.environ.get("AUTOEXEC_HOME")
    os.environ["PATH"] = "%s:%s/tools" % (os.environ["PATH"], autoexecHome)
    os.environ["OUTPUT_PATH"] = "output.json"
    os.chdir(os.environ.get("AUTOEXEC_WORK_PATH"))

    for net in args.nets.split(","):
        discovery(snmpPort, workerCount, net, ports, communities, timingTmpl)
